package BitonicSort(sortLe) where
import qualified Vector
import List
-- only used for testbench
-- import FileIO

-- Implementation taken from Satnam Singh's Lava work.
-- http://www.xilinx.com/labs/lava/sorter/sorter.htm

--@ \subsubsection{BitonicSort}
--@
--@ \index{BitonicSort@\te{BitonicSort} (package)|textbf}

-- Group consecutive elements two at a time:
--   [a, b, c, d]  ->  [(a, b), (c, d)]
-- A list with an odd number of elements has no valid pairing and
-- triggers `error "pairs"`.
pairs :: List a -> List (a, a)
pairs xs =
    case xs of
      Nil                  -> Nil
      Cons _ Nil           -> error "pairs"
      Cons a (Cons b rest) -> Cons (a, b) (pairs rest)

-- Flatten a list of pairs back into a plain list, keeping the order:
--   [(a, b), (c, d)]  ->  [a, b, c, d]
unpairs :: List (a, a) -> List a
unpairs ps =
    case ps of
      Nil              -> Nil
      Cons (a, b) rest -> Cons a (Cons b (unpairs rest))

-- Split a list at its midpoint (length `div` 2).  The first component
-- holds the leading elements, the second component holds the rest, so
-- for an odd length the second component is one element longer.
halve :: List a -> (List a, List a)
halve xs =
    let mid = length xs `div` 2
    in  (take mid xs, drop mid xs)

-- Join the two components of a pair of lists: the first list followed
-- by the second.  Undoes `halve`.
unhalve :: (List a, List a) -> List a
unhalve (front, back) = append front back

-- Interleave the first half of a list with its second half:
--   [a, b, c, d]  ->  [a, c, b, d]
-- NOTE(review): for odd lengths the two halves differ in size, so the
-- result depends on how `zip` treats the leftover element -- confirm
-- before using on anything but even-length lists.
riffle :: List a -> List a
riffle xs =
    case halve xs of
      (front, back) -> unpairs (zip front back)

-- Separate a list into the elements at even positions followed by the
-- elements at odd positions (0-based):
--   [a, c, b, d]  ->  [a, b, c, d]
-- Requires an even-length list, because it goes through `pairs`.
unriffle :: List a -> List a
unriffle xs = unhalve (unzip (pairs xs))

-- Apply a pair of functions component-wise to a pair of values: the
-- first function transforms the first component, the second function
-- transforms the second component.
papply :: (a->b, c->d) -> (a, c) -> (b, d)
papply (onFst, onSnd) (l, r) = (onFst l, onSnd r)

-- Run the same list transformer `r` independently on the first half
-- and on the second half of the input, then concatenate the results.
two :: (List a -> List b) -> List a -> List b
two r xs = unhalve (papply (r, r) (halve xs))

-- Interleaved application: un-riffle the input so that elements at
-- even positions form the first half and elements at odd positions
-- form the second half, run `r` on each half via `two`, then riffle
-- the two results back together.
ilv :: (List a -> List b) -> List a -> List b
ilv r xs = riffle (two r (unriffle xs))

-- Apply a two-input/two-output cell `r` to every adjacent pair of
-- elements (positions 0-1, 2-3, ...) and flatten the results.
-- Requires an even-length list, because it goes through `pairs`.
evens :: ((a, a) -> (b,b)) -> List a -> List b
evens r xs = unpairs (map r (pairs xs))

-- Butterfly network of depth `n` built from the two-input cell `r`.
--   * depth 0 is the identity;
--   * depth n first runs a depth (n-1) butterfly on the interleaved
--     sub-lists (see `ilv`), then applies `r` to each adjacent pair.
-- NOTE(review): intended for lists of length 2^n -- confirm with callers.
bfly :: ((a, a) -> (a, a)) -> Integer -> List a -> List a
bfly r n =
    if n == 0
    then id
    else evens r ∘ ilv (bfly r (n - 1))

-- Apply `f` to the second half of the list only; the first half is
-- passed through unchanged.
sndList :: (List a -> List a) -> List a -> List a
sndList f xs = unhalve (papply (id, f) (halve xs))

-- Recursive sorting network of depth `n` built from the two-input
-- cell `cmp`.
--   * depth 0 is the identity;
--   * depth n sorts each half with a depth (n-1) sorter, reverses the
--     second half, and feeds the result through a depth n butterfly.
-- NOTE(review): intended for lists of length 2^n -- confirm with callers.
sorter :: ((a, a) -> (a, a)) -> Integer -> List a -> List a
sorter cmp n =
    if n == 0
    then id
    else
      let sortHalves  = two (sorter cmp (n - 1))
          flipSecond  = sndList reverse
          mergeStage  = bfly cmp n
      in  mergeStage ∘ flipSecond ∘ sortHalves

-- Two-input compare-and-swap cell.  Given a "less than or equal"
-- predicate `le`, return the pair unchanged when `le` holds for
-- (first, second); otherwise return the pair with its components
-- exchanged.
cmpSwap :: (a -> a -> Bool) -> (a, a) -> (a, a)
cmpSwap le (a, b) =
    if le a b
    then (a, b)
    else (b, a)

--@ Sort a list of items given a predicate that compares elements for $<=$.
--@ The \te{sortLe} function generates a sorting network using Batcher's bitonic sort.
--@ The length of the list must be a power of two.
--@ \index{sortLe@\te{sortLe} (function)|textbf}
--@ \begin{libverbatim}
--@ function Vector#(n, a) sortLe(function Bool le(a x1, a x2))
--@ \end{libverbatim}
--sortLe :: (Log n k, Add n 1 m, Add k 1 j, Log m j) => (a -> a -> Bool) -> Vector n a -> Vector n a
-- Sort a vector with a sorting network built from compare-and-swap
-- cells driven by the supplied "less than or equal" predicate.
--
-- The number of network stages is the type-level value `k` bound by
-- the `Log n k` proviso.  The vector length `n` must satisfy
-- 2^k == n; otherwise the function calls `error`.
-- NOTE(review): presumably `Log n k` is a ceiling log2, making this a
-- power-of-two check as the error message says -- confirm.
sortLe :: (Log n k) => (a -> a -> Bool) -> Vector.Vector n a -> Vector.Vector n a
sortLe le =
  let stages = valueOf k
  in  if (2 ** stages) == valueOf n
      then Vector.toVector ∘ sorter (cmpSwap le) stages ∘ Vector.toList
      else error "sortLe: length of list not power of 2"



{-
type M = 8
type N = 8
type Data = Bit 8

interface Sort =
    sort :: Vector M Data -> Vector N Data

{-# verilog mkSort #-}
mkSort :: Module Sort
mkSort =
    module
      interface
        sort = Vector.take ∘ sortLe (<=)


{-# verilog testSort #-}
testSort :: Module Empty
testSort =
    module
        s :: Sort <- mkSort
        p :: FilePut (Vector N Data) <- mkFileHexPut "-"
        g :: FileGet (Vector M Data) <- mkFileHexGet "-"
        eof :: Reg Bool <- mkReg False
        rules
         when not eof
          ==> action
                mi :: Maybe (Vector M Data)
                mi <- g.get
                case mi of
                 Just x -> p.put (s.sort x)
                 Nothing -> eof := True
        interface
         { }

-- Sample input data, should all give 0x0807060504030201
-- 0x0102030405060708
-- 0x0807060504030201
-- 0x0701060802030504
-}
